import os
import io
import json
import numpy as np
import pandas as pd
from PIL import Image

def convert_parquet_to_images_or_json(parquet_file_path, output_dir):
    """
    Restore the files stored in a Parquet table to disk.

    Every row must have a ``filename`` column holding a path that is joined
    onto *output_dir*. If the table has an ``image_data`` column, each value is
    decoded with PIL and re-saved under that name (PIL picks the output format
    from the filename's extension). Otherwise, if it has a ``json_data``
    column, each value is converted to plain Python objects and written as
    JSON. A table with neither column produces one warning per row.

    Args:
        parquet_file_path: Path of the Parquet file to read.
        output_dir: Directory the ``filename`` values are resolved against.
            Missing intermediate directories are created as needed.

    Raises:
        KeyError: If a payload column exists but ``filename`` does not.
    """
    df = pd.read_parquet(parquet_file_path)

    # The payload column is a property of the whole table, so decide once
    # instead of re-checking df.columns on every row. 'image_data' takes
    # precedence when both columns are present.
    if 'image_data' in df.columns:
        payload_column = 'image_data'
    elif 'json_data' in df.columns:
        payload_column = 'json_data'
    else:
        payload_column = None

    if payload_column is None:
        for index in df.index:
            print(f"Unknown data format in {parquet_file_path}. Skipping row {index}.")
        return

    # Zipping the columns avoids DataFrame.iterrows(), which materializes a
    # Series for every row and is much slower on large tables.
    for index, filename, payload in zip(df.index, df['filename'], df[payload_column]):
        if payload_column == 'image_data':
            if payload is None:
                # A null cell would make PIL raise and abort the whole batch.
                print(f"No image data for {filename} in {parquet_file_path}. Skipping row {index}.")
                continue
            _save_image(payload, _prepare_output_path(output_dir, filename))
        else:
            _save_json(payload, _prepare_output_path(output_dir, filename))


def _prepare_output_path(output_dir, filename):
    """Join *filename* onto *output_dir*, create its parent dirs, return the path."""
    # NOTE(review): an absolute `filename` or one containing '..' makes
    # os.path.join resolve outside output_dir; confirm the Parquet source is
    # trusted before running this on third-party data.
    full_output_path = os.path.join(output_dir, filename)
    parent_dir = os.path.dirname(full_output_path)
    # dirname() is '' for a bare filename with an empty output_dir, and
    # os.makedirs('') raises FileNotFoundError.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    return full_output_path


def _save_image(image_data, full_output_path):
    """Decode encoded image bytes with PIL and save the image to *full_output_path*."""
    # The context manager releases PIL's handle on the buffer once saved.
    with Image.open(io.BytesIO(image_data)) as image:
        image.save(full_output_path)
    print(f"Saved image: {full_output_path}")


def _save_json(json_data, full_output_path):
    """Write *json_data* as JSON to *full_output_path* after stripping NumPy types."""
    json_data = convert_numpy_to_list(json_data)
    with open(full_output_path, 'w', encoding='utf-8') as json_file:
        json.dump(json_data, json_file)
    print(f"Saved JSON: {full_output_path}")

def convert_numpy_to_list(obj):
    """
    Recursively convert NumPy containers and scalars into plain Python objects.

    Walks dicts, lists and tuples. Every ``np.ndarray`` becomes a (nested)
    Python list, and every NumPy scalar (``np.generic``, e.g. ``np.int64``) becomes
    the equivalent Python scalar, so the result can be passed to ``json.dump``.
    Tuples are returned as lists, which is how JSON encodes them anyway.
    Dict keys are left untouched. Any other value is returned unchanged.

    Args:
        obj: The value to convert.

    Returns:
        A structurally equivalent object with NumPy types replaced.
    """
    if isinstance(obj, dict):
        return {k: convert_numpy_to_list(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_numpy_to_list(elem) for elem in obj]
    if isinstance(obj, np.ndarray):
        # tolist() already yields Python scalars for numeric arrays, but an
        # object-dtype array can still hold nested ndarrays, so recurse.
        return convert_numpy_to_list(obj.tolist())
    if isinstance(obj, np.generic):
        # json cannot serialize e.g. np.int64 or np.bool_; .item() returns the
        # matching built-in Python scalar.
        return obj.item()
    return obj

# Example usage: restore every Parquet file found in `parquet_dir` into
# `output_image_dir`. These paths are machine-specific; edit them before
# running elsewhere.
parquet_dir = "/ifs/root/ipa01/101/user_101002/Project/ControlVAR-main/dataset/imagenet2017/depth"
output_image_dir = "/ifs/root/ipa01/101/user_101002/Project/ControlVAR-main/dataset/imagenet2017/depth_images"
os.makedirs(output_image_dir, exist_ok=True)

# List all Parquet files in the directory. os.listdir() returns entries in an
# arbitrary, filesystem-dependent order, so sort them to make the processing
# (and log) order reproducible between runs.
parquet_files = sorted(f for f in os.listdir(parquet_dir) if f.endswith('.parquet'))

# Convert all Parquet files back to images or JSON
for parquet_file in parquet_files:
    parquet_file_path = os.path.join(parquet_dir, parquet_file)
    print(f"Processing {parquet_file_path}...")
    convert_parquet_to_images_or_json(parquet_file_path, output_image_dir)